from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

# Load the MNIST training split. ToTensor() converts each PIL image into a
# float tensor of shape (C, H, W) scaled to [0, 1].
# root='' resolves relative to the current working directory.
# download=True fetches the files on first run. Without it, MNIST raises
# RuntimeError when the data is not already on disk. If the files are
# already present, no download is attempted.
dataset = MNIST(root='', train=True, download=True, transform=ToTensor())

# Show the first 10 training images in a 2x5 grid, titled with their labels.
fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(10, 5))
for i, axi in enumerate(ax.flat):
    image, label = dataset[i]
    # Drop the singleton grayscale channel so imshow receives a 2-D array.
    image = image.squeeze().numpy()
    axi.imshow(image, cmap='gray')
    axi.set(title=f"Label: {label}")
    axi.axis('off')
# Prevent subplot titles from overlapping neighbouring axes.
fig.tight_layout()
plt.show()

